### Function library for maximum likelihood estimation of stopping rules research

# Load the theory-specific functions (wrapMLE_*, log_lik_*, stat_sum_*) for every supported model
source("Replication_scripts/stopping_rules_func_library/MaxLike_2sQTR_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_CRRA_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_disapAversion_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_salience_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_CPT_3param_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_CPT_5param_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_regret_Aversion_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_RDU_prelec_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_RDU_GE_function_lib.R")
source("Replication_scripts/stopping_rules_func_library/MaxLike_RDU_KT_function_lib.R")

# Individual-level stopping-rule estimation function, plus helper functions that dispatch to the theory-specific functions sourced above

# Estimate a model separately for every subject (one fit per `studentID`).
#
# Arguments:
#   data            data frame with at least a `studentID` column; it is first
#                   passed through sample_SR(data, sub_sample).
#   sub_sample      sub-sample selector forwarded to sample_SR().
#   util_fun        model name used to look up the fitting wrapper, the
#                   log-likelihood and the summary function via def_wrapMLE(),
#                   def_logLik() and def_stat_sum().
#   init            starting values, forwarded unchanged to the fitting wrapper.
#   df_regret_probs optional extra argument; when not NULL it is forwarded as
#                   the fourth argument of every call to the fitting wrapper.
#   bootstrapping   if TRUE, bootstrap summaries (bootstrap_estimation() +
#                   bootstrap_sum_stats()) are appended as extra columns.
#
# Returns the row-bound per-subject summaries with a `studentID` column
# (plus the bootstrap summary columns when requested). Prints the elapsed time.
MaxLikeSR <- function(data, sub_sample, util_fun, init,
                      df_regret_probs = NULL, bootstrapping = FALSE) {
  start <- Sys.time()

  data <- sample_SR(data, sub_sample)
  index <- data$studentID
  # drop = TRUE avoids empty groups for unused factor levels.
  groups <- split(data, f = index, drop = TRUE)

  wrapMLE <- def_wrapMLE(util_fun)
  logLikFun <- def_logLik(util_fun)
  stat_sum <- def_stat_sum(util_fun)

  # Wrap the fitter so that `df_regret_probs` (when supplied) reaches every
  # fit, including the bootstrap fits, instead of being silently dropped.
  if (is.null(df_regret_probs)) {
    fit_one <- wrapMLE
  } else {
    fit_one <- function(subject_data, init, logLikFun) {
      wrapMLE(subject_data, init, logLikFun, df_regret_probs)
    }
  }

  # Bootstrap first (when requested) so the order of the calls is unchanged.
  param_stats <- NULL
  if (isTRUE(as.logical(bootstrapping))) {
    param_distributions <- lapply(groups, bootstrap_estimation,
                                  fit_one, init, logLikFun)
    param_stats <- do.call(rbind,
                           lapply(param_distributions, bootstrap_sum_stats))
  }

  fits <- lapply(groups, fit_one, init, logLikFun)
  Res <- do.call(rbind, lapply(fits, stat_sum))

  # split() orders the groups by sorted factor level, not by first appearance,
  # so the IDs must be sorted to line up with the rows of `Res`.
  Res$studentID <- sort(unique(index))

  if (!is.null(param_stats)) {
    Res <- cbind(Res, param_stats)
  }

  end <- Sys.time()
  timer <- difftime(end, start, units = "secs")
  print(paste("Time taken for MLE algorithim:",seconds_to_period(round(timer[[1]],2))))
  return(Res)
}


# Look up the model-specific fitting wrapper (wrapMLE_*) for `util_fun`.
#
# `util_fun` is a single model name; an unknown or malformed name raises an
# explicit error instead of failing later with "object 'wrapMLE' not found".
# Only the selected wrapper is evaluated, so the others need not exist.
def_wrapMLE <- function(util_fun) {
  model <- as.character(util_fun)
  if (length(model) != 1L || is.na(model)) {
    stop("`util_fun` must be a single, non-missing model name.", call. = FALSE)
  }
  switch(model,
    "CRRA" = wrapMLE_CRRA,
    "disapAver" = wrapMLE_disappointment,
    "regretAver" = wrapMLE_regret,
    "CPT_3param" = wrapMLE_CPT_3param,
    "2sQTR" = wrapMLE_2sQTR,
    "salience" = wrapMLE_salience,
    "CPT_5param" = wrapMLE_CPT_5param,
    "RDU_prelec" = wrapMLE_RDU_prelec,
    "RDU_GE" = wrapMLE_RDU_GE,
    "RDU_KT" = wrapMLE_RDU_KT,
    stop("Unknown util_fun: '", model, "'", call. = FALSE)
  )
}

# Look up the model-specific log-likelihood function (log_lik_*) for `util_fun`.
#
# `util_fun` is a single model name; an unknown or malformed name raises an
# explicit error instead of failing later with "object 'logLikFun' not found".
# Only the selected function is evaluated, so the others need not exist.
def_logLik <- function(util_fun) {
  model <- as.character(util_fun)
  if (length(model) != 1L || is.na(model)) {
    stop("`util_fun` must be a single, non-missing model name.", call. = FALSE)
  }
  switch(model,
    "CRRA" = log_lik_CRRA,
    "disapAver" = log_lik_disapAver,
    "regretAver" = log_lik_regretAver,
    "CPT_3param" = log_lik_CPT_3param,
    "2sQTR" = log_lik_2sQTR,
    "salience" = log_lik_salience,
    "CPT_5param" = log_lik_CPT_5param,
    "RDU_prelec" = log_lik_RDU_prelec,
    "RDU_GE" = log_lik_RDU_GE,
    "RDU_KT" = log_lik_RDU_KT,
    stop("Unknown util_fun: '", model, "'", call. = FALSE)
  )
}

# Look up the model-specific summary function (stat_sum_*) for `util_fun`.
#
# `util_fun` is a single model name; an unknown or malformed name raises an
# explicit error instead of failing later with "object 'stat_sum' not found".
# Only the selected function is evaluated, so the others need not exist.
def_stat_sum <- function(util_fun) {
  model <- as.character(util_fun)
  if (length(model) != 1L || is.na(model)) {
    stop("`util_fun` must be a single, non-missing model name.", call. = FALSE)
  }
  switch(model,
    "CRRA" = stat_sum_CRRA,
    "disapAver" = stat_sum_disapAver,
    "regretAver" = stat_sum_regret,
    "CPT_3param" = stat_sum_CPT_3param,
    "2sQTR" = stat_sum_2sQTR,
    "salience" = stat_sum_salience,
    "CPT_5param" = stat_sum_CPT_5param,
    "RDU_prelec" = stat_sum_RDU_prelec,
    "RDU_GE" = stat_sum_RDU_GE,
    "RDU_KT" = stat_sum_RDU_KT,
    stop("Unknown util_fun: '", model, "'", call. = FALSE)
  )
}


# Re-estimate a model on bootstrap resamples of one subject's data.
#
# Arguments:
#   data       data frame to resample (rows are drawn with replacement).
#   wrapMLE    fitting function, called as wrapMLE(resample, init, logLikFun).
#   init       starting values forwarded to `wrapMLE`.
#   logLikFun  log-likelihood function forwarded to `wrapMLE`.
#   n_boot     number of bootstrap resamples (default 1000, as before).
#   n_draws    rows drawn per resample (default 36, as before).
#              NOTE(review): 36 is independent of nrow(data); confirm this is
#              intended when `data` holds only part of the questions.
#
# Returns a list with elements `epsilon`, `alpha` and `beta`, each a list of
# `n_boot` values extracted from the fits by the get_bootstrap_*() helpers.
bootstrap_estimation <- function(data, wrapMLE, init, logLikFun,
                                 n_boot = 1000, n_draws = 36) {
  bootstrap_data_sets <- lapply(seq_len(n_boot), function(i) {
    slice_sample(data, n = n_draws, replace = TRUE)
  })
  bootstrap_results <- lapply(bootstrap_data_sets, wrapMLE, init, logLikFun)

  list(
    epsilon = lapply(bootstrap_results, get_bootstrap_epsilon),
    alpha = lapply(bootstrap_results, get_bootstrap_alpha),
    beta = lapply(bootstrap_results, get_bootstrap_beta)
  )
}

# Extract the "epsilon" entry from the `coef` slot of a fitted object.
# `data` must expose a `coef` slot (presumably an mle-type fit -- confirm).
get_bootstrap_epsilon <- function(data) {
  estimates <- data@coef
  estimates[["epsilon"]]
}

# Extract the "alpha" entry from the `coef` slot of a fitted object.
# `data` must expose a `coef` slot (presumably an mle-type fit -- confirm).
get_bootstrap_alpha <- function(data) {
  estimates <- data@coef
  estimates[["alpha"]]
}

# Extract the "beta" entry from the `coef` slot of a fitted object.
# `data` must expose a `coef` slot (presumably an mle-type fit -- confirm).
get_bootstrap_beta <- function(data) {
  estimates <- data@coef
  estimates[["beta"]]
}


# Summarise the bootstrap replicates of epsilon, alpha and beta.
#
# `data` is a list with elements `epsilon`, `alpha` and `beta`, each holding
# the bootstrap replicates of that parameter. For every parameter a one-sample
# t-test is run (null value 0 for epsilon, 0.5 for alpha and beta) and its
# mean, standard error, p-value and 95% confidence interval are reported.
#
# If t.test() fails (e.g. constant replicates), the mean and standard error are
# computed directly and the p-value / interval are set to NA. Proper NA values
# are used rather than the string 'NA', so the p-value column stays numeric
# when the per-subject rows are bound together.
#
# Returns a one-row data frame with columns mean_*_boot, std_*_boot,
# p_val_*_boot and CI_95_*_boot for epsilon, alpha and beta.
bootstrap_sum_stats <- function(data) {
  # Summary statistics for the replicates of a single parameter.
  summarise_replicates <- function(replicates, mu) {
    draws <- unlist(replicates)
    test <- try(t.test(draws, mu = mu), silent = TRUE)

    if (inherits(test, "try-error")) {
      return(list(
        mean = mean(draws),
        std = sqrt(var(draws)) / sqrt(length(draws)),
        p_val = NA_real_,
        ci = NA_character_
      ))
    }

    bounds <- round(test[["conf.int"]], 3)
    list(
      mean = test[["estimate"]][["mean of x"]],
      std = test[["stderr"]],
      p_val = test[["p.value"]],
      ci = paste0("(", bounds[1], " , ", bounds[2], ")")
    )
  }

  eps <- summarise_replicates(data[["epsilon"]], mu = 0)
  alp <- summarise_replicates(data[["alpha"]], mu = 0.5)
  bet <- summarise_replicates(data[["beta"]], mu = 0.5)

  data.frame(
    "mean_epsilon_boot" = eps$mean,
    "mean_alpha_boot" = alp$mean,
    "mean_beta_boot" = bet$mean,
    "std_epsilon_boot" = eps$std,
    "std_alpha_boot" = alp$std,
    "std_beta_boot" = bet$std,
    "p_val_epsilon_boot" = eps$p_val,
    "p_val_alpha_boot" = alp$p_val,
    "p_val_beta_boot" = bet$p_val,
    "CI_95_epsilon_boot" = eps$ci,
    "CI_95_alpha_boot" = alp$ci,
    "CI_95_beta_boot" = bet$ci
  )
}


# Restrict the data to the requested block of questions.
#
# "part1" keeps rows with qnum < 19 and "part2" keeps rows with qnum > 18.
# Any other value of `sub_sample` (including "full") returns `data` unchanged.
sample_SR <- function(data, sub_sample) {
  if (sub_sample == "part1") {
    return(filter(data, qnum < 19))
  }
  if (sub_sample == "part2") {
    return(filter(data, qnum > 18))
  }
  data
}
